{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.0091635 , -0.01171564, -0.03937805, -0.02005912], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"CartPole-v1 exposed through the old 4-tuple gym API with an episode-length cap.\n",
    "\n",
    "    reset() returns only the observation; step() returns (state, reward, done, info),\n",
    "    where done merges terminated, truncated and the max_steps cap.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        # Hard cap on the number of steps per episode\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        # Forward seed/options to the wrapped env so episodes can be reproduced, e.g. reset(seed=0)\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        # NOTE: truncation (time limit / step cap) is folded into done, so downstream\n",
    "        # bootstrapping treats it the same as a real failure terminal.\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAApGklEQVR4nO3df3RU9Z3/8ddMfoyEMJMGSCYpCaIgECHYgoZZW2uXlBDQlTWeo5YV7HLkyCaeaqzFdK2K7TGu7ll/dBXOnu2Ke46U1h7RQgUbQUKtETUlyw81KyzdYMkkVJqZJJqf8/n+4ZfZjoYfk4Tcz5Dn45x7TuZ+PnPnfT8nZF587i+XMcYIAADAIm6nCwAAAPg8AgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsI6jAeXpp5/WhRdeqAsuuEBFRUV6++23nSwHAABYwrGA8vOf/1yVlZV64IEH9Pvf/15z5sxRSUmJWltbnSoJAABYwuXUwwKLiop0+eWX61//9V8lSZFIRHl5ebrjjjt07733OlESAACwRLITH9rT06P6+npVVVVF17ndbhUXF6uuru4L/bu7u9Xd3R19HYlEdOLECY0fP14ul2tEagYAAENjjFF7e7tyc3Pldp/+II4jAeVPf/qT+vv7lZ2dHbM+OztbH3zwwRf6V1dXa+3atSNVHgAAOIeOHj2qSZMmnbaPIwElXlVVVaqsrIy+DoVCys/P19GjR+X1eh2sDAAAnK1wOKy8vDyNGzfujH0dCSgTJkxQUlKSWlpaYta3tLTI7/d/ob/H45HH4/nCeq/XS0ABACDBnM3pGY5cxZOamqq5c+dqx44d0XWRSEQ7duxQIBBwoiQAAGARxw7xVFZWasWKFZo3b56uuOIKPfHEE+rs7NR3vvMdp0oCAACWcCyg3HjjjTp+/Ljuv/9+BYNBXXbZZdq+ffsXTpwFAACjj2P3QRmKcDgsn8+nUCjEOSgAACSIeL6/eRYPAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1hj2gPPjgg3K5XDHLjBkzou1dXV0qLy/X+PHjlZ6errKyMrW0tAx3GQAAIIGdkxmUSy+9VM3NzdHljTfeiLbddddd2rJli1544QXV1tbq2LFjuv76689FGQAAIEEln5ONJifL7/d/YX0oFNJPf/pTbdy4UX/9138tSXr22Wc1c+ZMvfXWW5o/f/65KAcAACSYczKD8uGHHyo3N1cXXXSRli1bpqamJklSfX29ent7VVxcHO07Y8YM5efnq66u7pTb6+7uVjgcjlkAAMD5a9gDSlFRkTZs2KDt27dr3bp1OnLkiL7+9a+rvb1dwWBQqampysjIiHlPdna2gsHgKbdZXV0tn88XXfLy8oa7bAAAYJFhP8RTWloa/bmwsFBFRUWaPHmyfvGLX2jMmDGD2mZVVZUqKyujr8PhMCEFAIDz2Dm/zDgjI0OXXHKJDh06JL/fr56eHrW1tcX0aWlpGfCclZM8Ho+8Xm/MAgAAzl/nPKB0dHTo8OHDysnJ0dy5c5WSkqIdO3ZE2xsbG9XU1KRAIHCuSwEAAAli2A/xfO9739O1116ryZMn69ixY3rggQeUlJSkm2++WT6fTytXrlRlZaUyMzPl9Xp1xx13KBAIcAUPAACIGvaA8tFHH+nmm2/Wxx9/rIkTJ+prX/ua3nrrLU2cOFGS9Pjjj8vtdqusrEzd3d0qKSnRM888M9xlAACABOYyxhini4hXOByWz+dT
KBTifBQAABJEPN/fPIsHAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGCduAPK7t27de211yo3N1cul0svvfRSTLsxRvfff79ycnI0ZswYFRcX68MPP4zpc+LECS1btkxer1cZGRlauXKlOjo6hrQjAADg/BF3QOns7NScOXP09NNPD9j+6KOP6qmnntL69eu1Z88ejR07ViUlJerq6or2WbZsmQ4ePKiamhpt3bpVu3fv1qpVqwa/FwAA4LziMsaYQb/Z5dLmzZu1dOlSSZ/NnuTm5uruu+/W9773PUlSKBRSdna2NmzYoJtuuknvv/++CgoK9M4772jevHmSpO3bt2vx4sX66KOPlJube8bPDYfD8vl8CoVC8nq9gy0fAACMoHi+v4f1HJQjR44oGAyquLg4us7n86moqEh1dXWSpLq6OmVkZETDiSQVFxfL7XZrz549A263u7tb4XA4ZgEAAOevYQ0owWBQkpSdnR2zPjs7O9oWDAaVlZUV056cnKzMzMxon8+rrq6Wz+eLLnl5ecNZNgAAsExCXMVTVVWlUCgUXY4ePep0SQAA4Bwa1oDi9/slSS0tLTHrW1paom1+v1+tra0x7X19fTpx4kS0z+d5PB55vd6YBQAAnL+GNaBMmTJFfr9fO3bsiK4Lh8Pas2ePAoGAJCkQCKitrU319fXRPjt37lQkElFRUdFwlgMAABJUcrxv6Ojo0KFDh6Kvjxw5ooaGBmVmZio/P1933nmnfvzjH2vatGmaMmWKfvjDHyo3Nzd6pc/MmTO1aNEi3XbbbVq/fr16e3tVUVGhm2666ayu4AEAAOe/uAPKu+++q29+85vR15WVlZKkFStWaMOGDfr+97+vzs5OrVq1Sm1tbfra176m7du364ILLoi+5/nnn1dFRYUWLFggt9utsrIyPfXUU8OwOwAA4HwwpPugOIX7oAAAkHgcuw8KAADAcCCgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwTtwBZffu3br22muVm5srl8ull156Kab91ltvlcvlilkWLVoU0+fEiRNatmyZvF6vMjIytHLlSnV0dAxpRwAAwPkj7oDS2dmpOXPm6Omnnz5ln0WLFqm5uTm6/OxnP4tpX7ZsmQ4ePKiamhpt3bpVu3fv1qpVq+KvHgAAnJeS431DaWmpSktLT9vH4/HI7/cP2Pb+++9r+/bteueddzRv3jxJ0k9+8hMtXrxY//zP/6zc3Nx4SwIAAOeZc3IOyq5du5SVlaXp06dr9erV+vjjj6NtdXV1ysjIiIYTSSouLpbb7daePXsG3F53d7fC4XDMAgAAzl/DHlAWLVqk//zP/9SOHTv0T//0T6qtrVVpaan6+/slScFgUFlZWTHvSU5OVmZmpoLB4IDbrK6uls/niy55eXnDXTYAALBI3Id4zuSmm26K/jx79mwVFhbq4osv1q5du7RgwYJBbbOqqkqVlZXR1+FwmJACAMB57JxfZnzRRRdpwoQJOnTokCTJ7/ertbU1pk9fX59OnDhxyvNWPB6PvF5vzAIAAM5f5zygfPTRR/r444+Vk5MjSQoEAmpra1N9fX20z86dOxWJRFRUVHSuywEAAAkg
7kM8HR0d0dkQSTpy5IgaGhqUmZmpzMxMrV27VmVlZfL7/Tp8+LC+//3va+rUqSopKZEkzZw5U4sWLdJtt92m9evXq7e3VxUVFbrpppu4ggcAAEiSXMYYE88bdu3apW9+85tfWL9ixQqtW7dOS5cu1d69e9XW1qbc3FwtXLhQP/rRj5SdnR3te+LECVVUVGjLli1yu90qKyvTU089pfT09LOqIRwOy+fzKRQKcbgHAIAEEc/3d9wBxQYEFAAAEk883988iwcAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArBP3wwIBYKj+/IcGHX//t6ftk559sXK/uniEKgJgGwIKgBHX3f6xQk37T9vH5U6SifTL5U4aoaoA2IRDPACsZCL9ivT3OV0GAIcQUABYyZiITKTf6TIAOISAAsBKJtIvwwwKMGoRUADYKcIMCjCaEVAAWMmYiCIRZlCA0YqAAsBKJtLPDAowihFQAFjJRCKcgwKMYgQUAHbiKh5gVCOgALBShKt4gFGNgAJgxKVnTdGYzC+ftk9XW7Pag4dGqCIAtiGgABhx7qQUuZLO8KQNY2RMZGQKAmAdAgqAEedyJ8nl4hk7AE6NgAJgxLmSkngIIIDTIqAAGHEud5Jcbv78ADg1/kIAGHEudzIzKABOi4ACYMR9NoNCQAFwagQUACOOQzwAzoS/EABGHFfxADiTuAJKdXW1Lr/8co0bN05ZWVlaunSpGhsbY/p0dXWpvLxc48ePV3p6usrKytTS0hLTp6mpSUuWLFFaWpqysrJ0zz33qK+PO0YCo4U7iXNQAJxeXAGltrZW5eXleuutt1RTU6Pe3l4tXLhQnZ2d0T533XWXtmzZohdeeEG1tbU6duyYrr/++mh7f3+/lixZop6eHr355pt67rnntGHDBt1///3Dt1cA7OZyn+UhHiNjzDkvB4B9XGYI//qPHz+urKws1dbW6qqrrlIoFNLEiRO1ceNG3XDDDZKkDz74QDNnzlRdXZ3mz5+vbdu26ZprrtGxY8eUnZ0tSVq/fr3WrFmj48ePKzU19YyfGw6H5fP5FAqF5PV6B1s+AAf9z86f6uMP95y2T+7ca5X71cXMtgDniXi+v4d0DkooFJIkZWZmSpLq6+vV29ur4uLiaJ8ZM2YoPz9fdXV1kqS6ujrNnj07Gk4kqaSkROFwWAcPHhzwc7q7uxUOh2MWAOc/E+njdvfAKDXogBKJRHTnnXfqyiuv1KxZsyRJwWBQqampysjIiOmbnZ2tYDAY7fOX4eRk+8m2gVRXV8vn80WXvLy8wZYNIIFE+vtkIgQUYDQadEApLy/XgQMHtGnTpuGsZ0BVVVUKhULR5ejRo+f8MwE4z/T3ScygAKPSGR4nOrCKigpt3bpVu3fv1qRJk6Lr/X6/enp61NbWFjOL0tLSIr/fH+3z9ttvx2zv5FU+J/t8nsfjkcfjGUypABJYpJ9DPMBoFdcMijFGFRUV2rx5s3bu3KkpU6bEtM+dO1cpKSnasWNHdF1jY6OampoUCAQkSYFAQPv371dra2u0T01NjbxerwoKCoayLwDOMybCIR5gtIprBqW8vFwbN27Uyy+/rHHjxkXPGfH5fBozZox8Pp9WrlypyspKZWZmyuv16o477lAgEND8+fMlSQsXLlRBQYFuueUWPfroowoGg7rvvvtUXl7OLAmAGBziAUavuALKunXrJElXX311zPpnn31Wt956qyTp8ccfl9vtVllZmbq7u1VSUqJnnnkm2jcpKUlbt27V6tWrFQgENHbsWK1YsUIPPfTQ0PYEwHmHq3iA0WtI90FxCvdBARLf2dwHxZdfqMlfu1meceNHqCoA59KI3QcFAAYrKTVNcp3+T1BfV4cifT0jVBEAmxBQADjCl3epki9IP22fztb/UU/Hn0eoIgA2IaAAcIQ7KUUul8vpMgBYioACwBGupBSJgALgFAgoABzhTkqW
6wznoAAYvfjrAMARruQUScygABgYAQWAIz6bQSGgABgYAQWAI1xJqWe8zBjA6MVfBwCOcCclc5IsgFMioABwhDuZy4wBnBoBBYAj3FxmDOA0CCgAHOFyJ8l1FlfxGBNRAj4yDMAQEVAAWM308yweYDQioACwWn9fr8QMCjDqEFAAWO2zpxkTUIDRhoACwGqmr5dzUIBRiIACwGoRzkEBRiUCCgCrRTgHBRiVCCgArBbp7xXnoACjDwEFgNVMXy/xBBiFCCgArBbp7+EQDzAKEVAAOCY9Z5p0hrvJtgcPy/T3jkxBAKxBQAHgmPSsKWd8Hk/Xn48pEukfoYoA2IKAAsAx7uRUp0sAYCkCCgDHuJM9cvFEYwADIKAAcIw7OcXpEgBYioACwDEc4gFwKgQUAI4hoAA4FQIKAMe4Uzw602XGAEYnAgoAxzCDAuBUCCgAHJOUnMoECoABxRVQqqurdfnll2vcuHHKysrS0qVL1djYGNPn6quvlsvlilluv/32mD5NTU1asmSJ0tLSlJWVpXvuuUd9fX1D3xsACYUZFACnkhxP59raWpWXl+vyyy9XX1+ffvCDH2jhwoV67733NHbs2Gi/2267TQ899FD0dVpaWvTn/v5+LVmyRH6/X2+++aaam5u1fPlypaSk6OGHHx6GXQKQMFxn938k098rYwz3TAFGkbgCyvbt22Neb9iwQVlZWaqvr9dVV10VXZ+Wlia/3z/gNn7zm9/ovffe02uvvabs7Gxddtll+tGPfqQ1a9bowQcfVGoq/6MCECvS2+N0CQBG2JDOQQmFQpKkzMzMmPXPP/+8JkyYoFmzZqmqqkqffPJJtK2urk6zZ89WdnZ2dF1JSYnC4bAOHjw44Od0d3crHA7HLABGj/6+bqdLADDC4ppB+UuRSER33nmnrrzySs2aNSu6/tvf/rYmT56s3Nxc7du3T2vWrFFjY6NefPFFSVIwGIwJJ5Kir4PB4ICfVV1drbVr1w62VAAJjhkUYPQZdEApLy/XgQMH9MYbb8SsX7VqVfTn2bNnKycnRwsWLNDhw4d18cUXD+qzqqqqVFlZGX0dDoeVl5c3uMIBJJwIMyjAqDOoQzwVFRXaunWrXn/9dU2aNOm0fYuKiiRJhw4dkiT5/X61tLTE9Dn5+lTnrXg8Hnm93pgFwOgR6WMGBRht4gooxhhVVFRo8+bN2rlzp6ZMmXLG9zQ0NEiScnJyJEmBQED79+9Xa2trtE9NTY28Xq8KCgriKQfAKNFPQAFGnbgO8ZSXl2vjxo16+eWXNW7cuOg5Iz6fT2PGjNHhw4e1ceNGLV68WOPHj9e+fft011136aqrrlJhYaEkaeHChSooKNAtt9yiRx99VMFgUPfdd5/Ky8vl8XiGfw8BJLxIL4d4gNEmrhmUdevWKRQK6eqrr1ZOTk50+fnPfy5JSk1N1WuvvaaFCxdqxowZuvvuu1VWVqYtW7ZEt5GUlKStW7cqKSlJgUBAf/d3f6fly5fH3DcFAP4Sh3iA0SeuGRRjzGnb8/LyVFtbe8btTJ48Wa+88ko8Hw1gFCOgAKMPz+IB4KiJM79xxj6tB3ed+0IAWIWAAsBRKWlnviqvv7drBCoBYBMCCgBH8cBAAAMhoABwVBIBBcAACCgAHMUMCoCBEFAAOMqdzP2PAHwRAQWAo9wpzKAA+CICCgBHMYMCYCAEFACOSmIGBcAACCgAHMVJsgAGQkAB4CgCCoCBEFAAOMblcklynVVfnscDjC4EFAAJwBBQgFGGgAIgIRBQgNGFgALAfoaAAow2BBQACSHS1+10CQBGEAEFQAIwivT1Ol0EgBFEQAGQEDjEA4wuBBQACYGAAowuyU4XACDx9fX1Dfq9/f1n8V4j9fV0DelzJMntdsvt5v9lQCIgoAAYsunTp6upqWlQ7x3vHaNN91+vsRec+o6yvX29qrxjtTa9fnCwJUqStmzZokWLFg1pGwBGBv+VADBkfX19g17CnZ+q5t3/Oe32k5Pc
Ki26eEif09fXJ2PMCI0IgKFiBgWAo4yRunv+79DNn3pyFeqbqIiSNMbdoYmpTfK4uxysEIATCCgAHGWMUXdfvyTp0Cdf1Uddl6grMlZGLqW4evRR13R91fsbh6sEMNI4xAPAUZ/NoPTryKezdfiTy/RpxCujJElu9ZoL9Oe+HL3Zdr0iJsnpUgGMIAIKAEcZGf2x068POucrcopJ3U8j6aoLLR3ZwgA4ioACwFHGSN29fZJcp+nlkjltO4DzDQEFgKOMMeru7Xe6DACWIaAAcJQxUlfP0G7ABuD8Q0AB4Cgjo3Qd1dS0d+VSZMA+Ka4uFfm2jHBlAJwUV0BZt26dCgsL5fV65fV6FQgEtG3btmh7V1eXysvLNX78eKWnp6usrEwtLS0x22hqatKSJUuUlpamrKws3XPPPUO+fTWAxGWM1Nvbo6ljfq8Lx+xXquuT/x9UjJJcPUpPOqGrvvRzpbi6nS4VwAiK6z4okyZN0iOPPKJp06bJGKPnnntO1113nfbu3atLL71Ud911l37961/rhRdekM/nU0VFha6//nr97ne/kyT19/dryZIl8vv9evPNN9Xc3Kzly5crJSVFDz/88DnZQQD2O/Zxu17+3QeSPlBL94X6c59f/SZZaUkh5XoO6xX3J2r9c6fTZQIYQS4zxHs/Z2Zm6rHHHtMNN9ygiRMnauPGjbrhhhskSR988IFmzpypuro6zZ8/X9u2bdM111yjY8eOKTs7W5K0fv16rVmzRsePH1dq6qmfxfGXwuGwfD6fbr311rN+D4BzZ+PGjero6HC6jDMqLS1VXl6e02UAo1ZPT482bNigUCgkr9d72r6DvpNsf3+/XnjhBXV2dioQCKi+vl69vb0qLi6O9pkxY4by8/OjAaWurk6zZ8+OhhNJKikp0erVq3Xw4EF95StfGfCzuru71d39f9O74XBYknTLLbcoPT19sLsAYJj86le/SoiAUlJSokAg4HQZwKjV0dGhDRs2nFXfuAPK/v37FQgE1NXVpfT0dG3evFkFBQVqaGhQamqqMjIyYvpnZ2crGAxKkoLBYEw4Odl+su1UqqurtXbt2i+snzdv3hkTGIBzL1FmMi+55BJdccUVTpcBjFonJxjORtxX8UyfPl0NDQ3as2ePVq9erRUrVui9996LdzNxqaqqUigUii5Hjx49p58HAACcFfcMSmpqqqZOnSpJmjt3rt555x09+eSTuvHGG9XT06O2traYWZSWlhb5/X5Jkt/v19tvvx2zvZNX+ZzsMxCPxyOPxxNvqQAAIEEN+T4okUhE3d3dmjt3rlJSUrRjx45oW2Njo5qamqLHfAOBgPbv36/W1tZon5qaGnm9XhUUFAy1FAAAcJ6IawalqqpKpaWlys/PV3t7uzZu3Khdu3bp1Vdflc/n08qVK1VZWanMzEx5vV7dcccdCgQCmj9/viRp4cKFKigo0C233KJHH31UwWBQ9913n8rLy5khAQAAUXEFlNbWVi1fvlzNzc3y+XwqLCzUq6++qm9961uSpMcff1xut1tlZWXq7u5WSUmJnnnmmej7k5KStHXrVq1evVqBQEBjx47VihUr9NBDDw3vXgEAgIQ25PugOOHkfVDO5jpqAOfe5MmT1dTU5HQZZ/TKK6+otLTU6TKAUSue72+exQMAAKxDQAEAANYhoAAAAOsQUAAAgHUG/SweADippKREx48fd7qMM/r8ozYA2IuAAmDI/u3f/s3pEgCcZzjEAwAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWCeugLJu3ToVFhbK6/XK6/UqEAho27Zt0farr75aLpcrZrn99ttjttHU1KQlS5YoLS1NWVlZuuee
e9TX1zc8ewMAAM4LyfF0njRpkh555BFNmzZNxhg999xzuu6667R3715deumlkqTbbrtNDz30UPQ9aWlp0Z/7+/u1ZMkS+f1+vfnmm2pubtby5cuVkpKihx9+eJh2CQAAJDqXMcYMZQOZmZl67LHHtHLlSl199dW67LLL9MQTTwzYd9u2bbrmmmt07NgxZWdnS5LWr1+vNWvW6Pjx40pNTT2rzwyHw/L5fAqFQvJ6vUMpHwAAjJB4vr8HfQ5Kf3+/Nm3apM7OTgUCgej6559/XhMmTNCsWbNUVVWlTz75JNpWV1en2bNnR8OJJJWUlCgcDuvgwYOn/Kzu7m6Fw+GYBQAAnL/iOsQjSfv371cgEFBXV5fS09O1efNmFRQUSJK+/e1va/LkycrNzdW+ffu0Zs0aNTY26sUXX5QkBYPBmHAiKfo6GAye8jOrq6u1du3aeEsFAAAJKu6AMn36dDU0NCgUCumXv/ylVqxYodraWhUUFGjVqlXRfrNnz1ZOTo4WLFigw4cP6+KLLx50kVVVVaqsrIy+DofDysvLG/T2AACA3eI+xJOamqqpU6dq7ty5qq6u1pw5c/Tkk08O2LeoqEiSdOjQIUmS3+9XS0tLTJ+Tr/1+/yk/0+PxRK8cOrkAAIDz15DvgxKJRNTd3T1gW0NDgyQpJydHkhQIBLR//361trZG+9TU1Mjr9UYPEwEAAMR1iKeqqkqlpaXKz89Xe3u7Nm7cqF27dunVV1/V4cOHtXHjRi1evFjjx4/Xvn37dNddd+mqq65SYWGhJGnhwoUqKCjQLbfcokcffVTBYFD33XefysvL5fF4zskOAgCAxBNXQGltbdXy5cvV3Nwsn8+nwsJCvfrqq/rWt76lo0eP6rXXXtMTTzyhzs5O5eXlqaysTPfdd1/0/UlJSdq6datWr16tQCCgsWPHasWKFTH3TQEAABjyfVCcwH1QAABIPCNyHxQAAIBzhYACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFgn2ekCBsMYI0kKh8MOVwIAAM7Wye/tk9/jp5OQAaW9vV2SlJeX53AlAAAgXu3t7fL5fKft4zJnE2MsE4lE1NjYqIKCAh09elRer9fpkhJWOBxWXl4e4zgMGMvhw1gOD8Zx+DCWw8MYo/b2duXm5srtPv1ZJgk5g+J2u/XlL39ZkuT1evllGQaM4/BhLIcPYzk8GMfhw1gO3ZlmTk7iJFkAAGAdAgoAALBOwgYUj8ejBx54QB6Px+lSEhrjOHwYy+HDWA4PxnH4MJYjLyFPkgUAAOe3hJ1BAQAA5y8CCgAAsA4BBQAAWIeAAgAArJOQAeXpp5/WhRdeqAsuuEBFRUV6++23nS7JOrt379a1116r3NxcuVwuvfTSSzHtxhjdf//9ysnJ0ZgxY1RcXKwPP/wwps+JEye0bNkyeb1eZWRkaOXKlero6BjBvXBedXW1Lr/8co0bN05ZWVlaunSpGhsbY/p0dXWpvLxc48ePV3p6usrKytTS0hLTp6mpSUuWLFFaWpqysrJ0zz33qK+vbyR3xVHr1q1TYWFh9CZXgUBA27Zti7YzhoP3yCOPyOVy6c4774yuYzzPzoMPPiiXyxWzzJgxI9rOODrMJJhNmzaZ1NRU8x//8R/m4MGD5rbbbjMZGRmmpaXF6dKs8sorr5h//Md/NC+++KKRZDZv3hzT/sgjjxifz2deeukl81//9V/mb/7mb8yUKVPMp59+Gu2zaNEiM2fOHPPWW2+Z3/72t2bq1Knm5ptvHuE9cVZJSYl59tlnzYED
B0xDQ4NZvHixyc/PNx0dHdE+t99+u8nLyzM7duww7777rpk/f775q7/6q2h7X1+fmTVrlikuLjZ79+41r7zyipkwYYKpqqpyYpcc8atf/cr8+te/Nv/93/9tGhsbzQ9+8AOTkpJiDhw4YIxhDAfr7bffNhdeeKEpLCw03/3ud6PrGc+z88ADD5hLL73UNDc3R5fjx49H2xlHZyVcQLniiitMeXl59HV/f7/Jzc011dXVDlZlt88HlEgkYvx+v3nsscei69ra2ozH4zE/+9nPjDHGvPfee0aSeeedd6J9tm3bZlwul/njH/84YrXbprW11UgytbW1xpjPxi0lJcW88MIL0T7vv/++kWTq6uqMMZ+FRbfbbYLBYLTPunXrjNfrNd3d3SO7Axb50pe+ZP793/+dMRyk9vZ2M23aNFNTU2O+8Y1vRAMK43n2HnjgATNnzpwB2xhH5yXUIZ6enh7V19eruLg4us7tdqu4uFh1dXUOVpZYjhw5omAwGDOOPp9PRUVF0XGsq6tTRkaG5s2bF+1TXFwst9utPXv2jHjNtgiFQpKkzMxMSVJ9fb16e3tjxnLGjBnKz8+PGcvZs2crOzs72qekpEThcFgHDx4cwert0N/fr02bNqmzs1OBQIAxHKTy8nItWbIkZtwkfifj9eGHHyo3N1cXXXSRli1bpqamJkmMow0S6mGBf/rTn9Tf3x/zyyBJ2dnZ+uCDDxyqKvEEg0FJGnAcT7YFg0FlZWXFtCcnJyszMzPaZ7SJRCK68847deWVV2rWrFmSPhun1NRUZWRkxPT9/FgONNYn20aL/fv3KxAIqKurS+np6dq8ebMKCgrU0NDAGMZp06ZN+v3vf6933nnnC238Tp69oqIibdiwQdOnT1dzc7PWrl2rr3/96zpw4ADjaIGECiiAk8rLy3XgwAG98cYbTpeSkKZPn66GhgaFQiH98pe/1IoVK1RbW+t0WQnn6NGj+u53v6uamhpdcMEFTpeT0EpLS6M/FxYWqqioSJMnT9YvfvELjRkzxsHKICXYVTwTJkxQUlLSF86ibmlpkd/vd6iqxHNyrE43jn6/X62trTHtfX19OnHixKgc64qKCm3dulWvv/66Jk2aFF3v9/vV09Ojtra2mP6fH8uBxvpk22iRmpqqqVOnau7cuaqurtacOXP05JNPMoZxqq+vV2trq7761a8qOTlZycnJqq2t1VNPPaXk5GRlZ2cznoOUkZGhSy65RIcOHeL30gIJFVBSU1M1d+5c7dixI7ouEolox44dCgQCDlaWWKZMmSK/3x8zjuFwWHv27ImOYyAQUFtbm+rr66N9du7cqUgkoqKiohGv2SnGGFVUVGjz5s3auXOnpkyZEtM+d+5cpaSkxIxlY2OjmpqaYsZy//79MYGvpqZGXq9XBQUFI7MjFopEIuru7mYM47RgwQLt379fDQ0N0WXevHlatmxZ9GfGc3A6Ojp0+PBh5eTk8HtpA6fP0o3Xpk2bjMfjMRs2bDDvvfeeWbVqlcnIyIg5ixqfneG/d+9es3fvXiPJ/Mu//IvZu3ev+d///V9jzGeXGWdkZJiXX37Z7Nu3z1x33XUDXmb8la98xezZs8e88cYbZtq0aaPuMuPVq1cbn89ndu3aFXMp4ieffBLtc/vtt5v8/Hyzc+dO8+6775pAIGACgUC0/eSliAsXLjQNDQ1m+/btZuLEiaPqUsR7773X1NbWmiNHjph9+/aZe++917hcLvOb3/zGGMMYDtVfXsVjDON5tu6++26za9cuc+TIEfO73/3OFBcXmwkTJpjW1lZjDOPotIQLKMYY85Of/MTk5+eb1NRUc8UVV5i33nrL6ZKs8/rrrxtJX1hWrFhhjPnsUuMf/vCHJjs723g8HrNgwQLT2NgYs42PP/7Y3HzzzSY9Pd14vV7zne98x7S3tzuwN84ZaAwlmWeffTba59NPPzX/8A//YL70pS+ZtLQ087d/+7emubk5Zjt/+MMfTGlpqRkzZoyZ
MGGCufvuu01vb+8I741z/v7v/95MnjzZpKammokTJ5oFCxZEw4kxjOFQfT6gMJ5n58YbbzQ5OTkmNTXVfPnLXzY33nijOXToULSdcXSWyxhjnJm7AQAAGFhCnYMCAABGBwIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKzz/wD+oVOFg7XuBgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Draw the game\n",
    "def show():\n",
    "    \"\"\"Render the environment's current RGB frame inline.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.6136, 0.3864],\n",
       "        [0.6010, 0.3990]], grad_fn=<SoftmaxBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Policy network: maps a 4-dim state to a probability for each of the 2 actions.\n",
    "# Softmax(dim=1) normalizes across the action axis, so inputs must be batched: [b, 4] -> [b, 2].\n",
    "model_action = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")\n",
    "\n",
    "model_action(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1294, -0.2356],\n",
       "        [ 0.0668, -0.2346]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def make_q_net():\n",
    "    \"\"\"Build a Q-network: 4-dim state -> one value per action, [b, 4] -> [b, 2].\"\"\"\n",
    "    return torch.nn.Sequential(\n",
    "        torch.nn.Linear(4, 128),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(128, 2),\n",
    "    )\n",
    "\n",
    "\n",
    "# Two online Q-networks (their element-wise minimum is used later) and their target copies.\n",
    "# Built in the same order as before, so random initialization is unchanged.\n",
    "model_value1 = make_q_net()\n",
    "model_value2 = make_q_net()\n",
    "model_value_next1 = make_q_net()\n",
    "model_value_next2 = make_q_net()\n",
    "\n",
    "# The target networks start as exact copies of the online networks\n",
    "model_value_next1.load_state_dict(model_value1.state_dict())\n",
    "model_value_next2.load_state_dict(model_value2.state_dict())\n",
    "\n",
    "model_value1(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    \"\"\"Sample an action index from the current policy for a single 4-dim state.\"\"\"\n",
    "    state = torch.FloatTensor(state).reshape(1, 4)\n",
    "    # Acting needs no gradients; avoid building an autograd graph on every env step\n",
    "    with torch.no_grad():\n",
    "        prob = model_action(state)\n",
    "\n",
    "    # Sample an action according to the policy's probabilities\n",
    "    action = random.choices(range(prob.shape[1]), weights=prob[0].tolist(), k=1)[0]\n",
    "\n",
    "    return action\n",
    "\n",
    "\n",
    "get_action([1, 2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((203, 0), 203)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replay buffer of (state, action, reward, next_state, over) transitions\n",
    "datas = []\n",
    "\n",
    "\n",
    "def update_data(new_count=200, max_size=10000):\n",
    "    \"\"\"Play whole episodes until at least new_count transitions were added to datas,\n",
    "    then drop the oldest entries so that at most max_size remain.\n",
    "\n",
    "    Returns (number of transitions added, number of transitions dropped).\n",
    "    \"\"\"\n",
    "    old_count = len(datas)\n",
    "\n",
    "    # Keep playing until at least new_count new transitions were collected;\n",
    "    # episodes are never cut short, so the count can overshoot.\n",
    "    while len(datas) - old_count < new_count:\n",
    "        # Start a new episode\n",
    "        state = env.reset()\n",
    "\n",
    "        # Play until the episode ends\n",
    "        over = False\n",
    "        while not over:\n",
    "            # Choose an action for the current state\n",
    "            action = get_action(state)\n",
    "\n",
    "            # Execute it and observe the feedback\n",
    "            next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "            # Record the transition\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            # Advance to the next state\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - max_size, 0)\n",
    "\n",
    "    # Enforce the capacity with one slice deletion of the oldest entries;\n",
    "    # repeated pop(0) shifts the whole list on every call.\n",
    "    # Done in place so every reference to the global datas list sees the change.\n",
    "    del datas[:drop_count]\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2238/1417522126.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.0860,  0.8179, -0.1682, -1.3635],\n",
       "         [ 0.1034,  0.4356, -0.1653, -0.8932],\n",
       "         [ 0.0267,  0.1994,  0.0031, -0.2416],\n",
       "         [ 0.0019, -0.9420,  0.0940,  1.5556],\n",
       "         [-0.0732, -0.7893, -0.0100,  1.1274]]),\n",
       " tensor([[0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1]]),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]]),\n",
       " tensor([[ 1.0240e-01,  6.2527e-01, -1.9550e-01, -1.1278e+00],\n",
       "         [ 1.1211e-01,  6.3251e-01, -1.8312e-01, -1.2330e+00],\n",
       "         [ 3.0643e-02,  3.9451e-01, -1.7803e-03, -5.3336e-01],\n",
       "         [-1.6931e-02, -1.1381e+00,  1.2513e-01,  1.8760e+00],\n",
       "         [-8.8971e-02, -5.9402e-01,  1.2518e-02,  8.3156e-01]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# Fetch one batch of transitions\n",
    "def get_sample(batch_size=64):\n",
    "    \"\"\"Draw batch_size random transitions from datas and stack them into tensors.\n",
    "\n",
    "    Returns (state [b, 4] float, action [b, 1] long, reward [b, 1] float,\n",
    "    next_state [b, 4] float, over [b, 1] long).\n",
    "    \"\"\"\n",
    "    # Sample without replacement from the replay buffer\n",
    "    samples = random.sample(datas, batch_size)\n",
    "\n",
    "    # States are numpy arrays: stack them into one ndarray first, because building a\n",
    "    # tensor from a list of ndarrays is very slow (torch emits a UserWarning for it).\n",
    "    #[b, 4]\n",
    "    state = torch.as_tensor(np.array([i[0] for i in samples]), dtype=torch.float32).reshape(-1, 4)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 4]\n",
    "    next_state = torch.as_tensor(np.array([i[3] for i in samples]), dtype=torch.float32).reshape(-1, 4)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state[:5], action[:5], reward[:5], next_state[:5], over[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Play one episode with the current policy and return its total reward (higher is better).\n",
    "\n",
    "    play: if truthy, redraw about 20% of the frames inline (random frame skipping).\n",
    "    \"\"\"\n",
    "    state = env.reset()\n",
    "    reward_sum = 0\n",
    "\n",
    "    done = False\n",
    "    while not done:\n",
    "        state, reward, done, _ = env.step(get_action(state))\n",
    "        reward_sum += reward\n",
    "\n",
    "        # Animate, skipping frames at random to keep redraws cheap\n",
    "        if play and random.random() < 0.2:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_update(model, model_next, tau=0.005):\n",
    "    \"\"\"Polyak-average model's parameters into model_next, in place.\n",
    "\n",
    "    Each target parameter becomes (1 - tau) * target + tau * online.\n",
    "    \"\"\"\n",
    "    # no_grad instead of .data: the update must not be recorded by autograd\n",
    "    with torch.no_grad():\n",
    "        for param, param_next in zip(model.parameters(), model_next.parameters()):\n",
    "            param_next.copy_(param_next * (1 - tau) + param * tau)\n",
    "\n",
    "\n",
    "soft_update(torch.nn.Linear(4, 64), torch.nn.Linear(4, 64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-4.6052, requires_grad=True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "# Entropy temperature; this is also a learnable parameter.\n",
    "# It holds log(alpha) (initial alpha = 0.01); users call alpha.exp() to recover the weight,\n",
    "# which keeps the effective weight positive under gradient updates.\n",
    "alpha = torch.tensor(math.log(0.01))\n",
    "alpha.requires_grad = True\n",
    "\n",
    "alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over, gamma=0.98):\n",
    "    \"\"\"Soft Bellman target for the Q-networks, computed without gradients.\n",
    "\n",
    "    target = reward + gamma * (1 - over) * (sum_a pi(a|s') * min(Q1', Q2')(s', a) + alpha * H(pi(.|s')))\n",
    "\n",
    "    reward: [b, 1] float, next_state: [b, 4] float, over: [b, 1] (1 = episode ended).\n",
    "    Returns a [b, 1] tensor that does not require grad.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Action probabilities at next_state\n",
    "        #[b, 4] -> [b, 2]\n",
    "        prob = model_action(next_state)\n",
    "\n",
    "        # Policy entropy, summed over the actions\n",
    "        #[b, 2] -> [b, 1]\n",
    "        entropy = -(prob * torch.log(prob + 1e-8)).sum(dim=1, keepdim=True)\n",
    "\n",
    "        # Evaluate next_state with both target networks; keep the smaller value for stability\n",
    "        #[b, 4] -> [b, 2]\n",
    "        target = torch.min(model_value_next1(next_state), model_value_next2(next_state))\n",
    "\n",
    "        # Expected value under the policy\n",
    "        #[b, 2] * [b, 2] -> [b, 2] -> [b, 1]\n",
    "        target = (prob * target).sum(dim=1, keepdim=True)\n",
    "\n",
    "        # Add the entropy bonus; alpha holds log(alpha), so exp() recovers the weight\n",
    "        #[b, 1] + [b, 1] -> [b, 1]\n",
    "        target = target + alpha.exp() * entropy\n",
    "\n",
    "        # Discount, zero the bootstrap term for finished episodes, add the immediate reward.\n",
    "        # NOTE: over also marks step-cap truncation, which is treated as terminal here.\n",
    "        #[b, 1]\n",
    "        target = target * gamma * (1 - over) + reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.1237, grad_fn=<NegBackward0>),\n",
       " tensor([[0.6922],\n",
       "         [0.6924],\n",
       "         [0.6902],\n",
       "         [0.6408],\n",
       "         [0.6512],\n",
       "         [0.6260],\n",
       "         [0.6913],\n",
       "         [0.6875],\n",
       "         [0.6764],\n",
       "         [0.6844],\n",
       "         [0.6903],\n",
       "         [0.6904],\n",
       "         [0.6920],\n",
       "         [0.6134],\n",
       "         [0.6871],\n",
       "         [0.6837],\n",
       "         [0.6923],\n",
       "         [0.6918],\n",
       "         [0.6847],\n",
       "         [0.6916],\n",
       "         [0.6641],\n",
       "         [0.6867],\n",
       "         [0.6615],\n",
       "         [0.6790],\n",
       "         [0.6923],\n",
       "         [0.6892],\n",
       "         [0.6867],\n",
       "         [0.6910],\n",
       "         [0.6924],\n",
       "         [0.6920],\n",
       "         [0.6533],\n",
       "         [0.6705],\n",
       "         [0.6923],\n",
       "         [0.6337],\n",
       "         [0.6925],\n",
       "         [0.6551],\n",
       "         [0.6748],\n",
       "         [0.6792],\n",
       "         [0.6924],\n",
       "         [0.6929],\n",
       "         [0.6908],\n",
       "         [0.6757],\n",
       "         [0.6873],\n",
       "         [0.6914],\n",
       "         [0.6478],\n",
       "         [0.6922],\n",
       "         [0.6573],\n",
       "         [0.6919],\n",
       "         [0.6904],\n",
       "         [0.6921],\n",
       "         [0.6927],\n",
       "         [0.6494],\n",
       "         [0.6909],\n",
       "         [0.6902],\n",
       "         [0.6922],\n",
       "         [0.6856],\n",
       "         [0.6862],\n",
       "         [0.6878],\n",
       "         [0.6908],\n",
       "         [0.6778],\n",
       "         [0.6362],\n",
       "         [0.6775],\n",
       "         [0.6920],\n",
       "         [0.6922]], grad_fn=<NegBackward0>))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss_action(state):\n",
    "    \"\"\"Policy loss: negative batch mean of (expected min-Q + alpha * entropy).\n",
    "\n",
    "    Returns (scalar loss, per-sample entropy [b, 1]); the entropy keeps its autograd graph.\n",
    "    \"\"\"\n",
    "    #[b, 4] -> [b, 2]\n",
    "    prob = model_action(state)\n",
    "\n",
    "    # Entropy of the action distribution for each sample\n",
    "    #[b, 2] -> [b, 1]\n",
    "    log_prob = (prob + 1e-8).log()\n",
    "    entropy = -(prob * log_prob).sum(dim=1, keepdim=True)\n",
    "\n",
    "    # Pessimistic value estimate: element-wise min of both online Q-networks\n",
    "    #[b, 4] -> [b, 2]\n",
    "    q_min = torch.min(model_value1(state), model_value2(state))\n",
    "\n",
    "    # Expected value under the policy\n",
    "    #[b, 2] -> [b, 1]\n",
    "    expected_q = (q_min * prob).sum(dim=1, keepdim=True)\n",
    "\n",
    "    # Objective to maximize: value plus the alpha-weighted entropy bonus\n",
    "    objective = expected_q + alpha.exp() * entropy\n",
    "\n",
    "    # Negate because optimizers minimize\n",
    "    return -objective.mean(), entropy\n",
    "\n",
    "\n",
    "get_loss_action(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 440 0.0024749916046857834 12.6\n",
      "20 6586 2.639788181113545e-05 188.3\n",
      "40 10000 3.0447399694821797e-06 199.8\n",
      "60 10000 4.20264484546351e-07 200.0\n",
      "80 10000 6.164231791672137e-08 195.0\n",
      "100 10000 8.757042202489629e-09 135.5\n",
      "120 10000 1.3429511946938533e-09 171.9\n",
      "140 10000 3.3358191209309496e-10 200.0\n",
      "160 10000 1.5044837597955052e-10 200.0\n",
      "180 10000 9.28849983039548e-11 200.0\n"
     ]
    }
   ],
   "source": [
    "def train(epochs=200, learn_times=200, target_entropy=-1.0, eval_every=20, eval_episodes=10):\n",
    "    \"\"\"Discrete soft actor-critic loop: alternate data collection with gradient updates.\n",
    "\n",
    "    epochs: number of collect-then-learn rounds.\n",
    "    learn_times: gradient updates after each data collection.\n",
    "    target_entropy: entropy level the alpha update steers toward.\n",
    "    eval_every / eval_episodes: reporting period (in epochs) and episodes averaged per report.\n",
    "\n",
    "    NOTE: a 2-action policy has entropy in [0, ln 2], so with the default target_entropy=-1\n",
    "    the alpha loss gradient is always positive and alpha decays toward 0 (the printed alpha\n",
    "    falls from ~2e-3 to ~1e-10). A positive target such as 0.98 * math.log(2) would keep the\n",
    "    entropy bonus active; TODO evaluate before changing the default.\n",
    "    \"\"\"\n",
    "    optimizer_action = torch.optim.Adam(model_action.parameters(), lr=1e-3)\n",
    "    optimizer_value1 = torch.optim.Adam(model_value1.parameters(), lr=1e-2)\n",
    "    optimizer_value2 = torch.optim.Adam(model_value2.parameters(), lr=1e-2)\n",
    "\n",
    "    # alpha is learned too, so it gets its own optimizer\n",
    "    optimizer_alpha = torch.optim.Adam([alpha], lr=1e-2)\n",
    "\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        # Add fresh transitions to the replay buffer\n",
    "        update_data()\n",
    "\n",
    "        # After each data update, learn learn_times times\n",
    "        for _ in range(learn_times):\n",
    "            # Sample a batch\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            # Soft target; it already includes the entropy bonus\n",
    "            #[b, 1]\n",
    "            target = get_target(reward, next_state, over).detach()\n",
    "\n",
    "            # Q-values of the actions actually taken\n",
    "            #[b, 2] -> [b, 1]\n",
    "            value1 = model_value1(state).gather(dim=1, index=action)\n",
    "            value2 = model_value2(state).gather(dim=1, index=action)\n",
    "\n",
    "            # Both Q-networks regress toward the same target\n",
    "            loss_value1 = loss_fn(value1, target)\n",
    "            loss_value2 = loss_fn(value2, target)\n",
    "\n",
    "            optimizer_value1.zero_grad()\n",
    "            loss_value1.backward()\n",
    "            optimizer_value1.step()\n",
    "\n",
    "            optimizer_value2.zero_grad()\n",
    "            loss_value2.backward()\n",
    "            optimizer_value2.step()\n",
    "\n",
    "            # Policy update, scored by the just-updated Q-networks\n",
    "            loss_action, entropy = get_loss_action(state)\n",
    "            optimizer_action.zero_grad()\n",
    "            loss_action.backward()\n",
    "            optimizer_action.step()\n",
    "\n",
    "            # Temperature loss: alpha * (entropy - target_entropy), entropy treated as a constant\n",
    "            #[b, 1] -> scalar\n",
    "            loss_alpha = ((entropy - target_entropy).detach() * alpha.exp()).mean()\n",
    "\n",
    "            optimizer_alpha.zero_grad()\n",
    "            loss_alpha.backward()\n",
    "            optimizer_alpha.step()\n",
    "\n",
    "            # Move the target networks a small step toward the online networks\n",
    "            soft_update(model_value1, model_value_next1)\n",
    "            soft_update(model_value2, model_value_next2)\n",
    "\n",
    "        if epoch % eval_every == 0:\n",
    "            test_result = sum(test(play=False) for _ in range(eval_episodes)) / eval_episodes\n",
    "            print(epoch, len(datas), alpha.exp().item(), test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl9klEQVR4nO3df3DU9YH/8deGJCsQdtMAySYlQRQKRAi2gGHP1rNHSoDoyRln1HIQewyMXOIUQimmR0XsjeHw5qr2FP64O/FmjLR0RCsKNgYJZw0/jOT4pTlhaINHNkG57CbRhCT7/v7hl890FTUbEvad+HzMfGay+3l/dt/7nszkObufz8ZljDECAACwSFysJwAAAPBZBAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwTkwD5amnntK1116ra665Rrm5uTp06FAspwMAACwRs0D59a9/rdLSUm3YsEHvvPOOZsyYofz8fDU3N8dqSgAAwBKuWP2zwNzcXM2ePVv/+q//KkkKh8PKzMzUAw88oAcffDAWUwIAAJaIj8WTXrx4UbW1tSorK3Pui4uLU15enmpqaj43vrOzU52dnc7tcDisCxcuaPTo0XK5XFdlzgAA4MoYY9Ta2qqMjAzFxX35hzgxCZQPP/xQPT09SktLi7g/LS1N77333ufGl5eXa+PGjVdregAAYACdPXtW48aN+9IxMQmUaJWVlam0tNS5HQwGlZWVpbNnz8rj8cRwZgAAoLdCoZAyMzM1atSorxwbk0AZM2aMhg0bpqampoj7m5qa5PP5Pjfe7XbL7XZ/7n6Px0OgAAAwyPTm9IyYXMWTmJiomTNnqqqqyrkvHA6rqqpKfr8/FlMCAAAWidlHPKWlpSoqKtKsWbN000036fHHH1d7e7t+9KMfxWpKAADAEjELlLvvvlvnz5/XQw89pEAgoBtvvFF79uz53ImzAADg6ydm34NyJUKhkLxer4LBIOegAAAwSETz95v/xQMAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6/R7oDz88MNyuVwR25QpU5z9HR0dKi4u1ujRo5WUlKTCwkI1NTX19zQAAMAgNiDvoNxwww1qbGx0tjfffNPZt3r1ar388svasWOHqqurde7cOd15550DMQ0AADBIxQ/Ig8bHy+fzfe7+YDCof//3f1dFRYX+6q/+SpL0zDPPaOrUqTpw4IDmzJkzENMBAACDzIC8g/L+++8rIyND1113nRYvXqyGhgZJUm1trbq6upSXl+eMnTJlirKyslRTU/OFj9fZ2alQKBSxAQCAoavfAyU3N1fbtm3Tnj17tGXLFp05c0bf+9731NraqkAgoMTERCUnJ0cck5aWpkAg8IWPWV5eLq/X62yZmZn9PW0AAGCRfv+IZ8GCBc7POTk5ys3N1fjx4/Wb3/xGw4cP79NjlpWVqbS01LkdCoWIFAAAhrABv8w4OTlZ3/rWt3Tq1Cn5fD5dvHhRLS0tEWOampoue87KJW63Wx6PJ2IDAABD14AHSltbm06fPq309HTNnDlTCQkJqqqqcvbX19eroaFBfr9/oKcCAAAGiX7/iOcnP/mJbr/9do0fP17nzp3Thg0bNGzYMN17773yer1atmyZSktLlZKSIo/HowceeEB+v58reAAAgKPfA+WDDz7Qvffeq48++khjx47Vd7/7XR04cEBjx46VJP3yl79UXFycCgsL1dnZqfz8fD399NP9PQ0AADCIuYwxJtaTiFYoFJLX61Uw
GOR8FAAABolo/n7zv3gAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWCfqQNm/f79uv/12ZWRkyOVy6cUXX4zYb4zRQw89pPT0dA0fPlx5eXl6//33I8ZcuHBBixcvlsfjUXJyspYtW6a2trYreiEAAGDoiDpQ2tvbNWPGDD311FOX3b9582Y9+eST2rp1qw4ePKiRI0cqPz9fHR0dzpjFixfrxIkTqqys1K5du7R//36tWLGi768CAAAMKS5jjOnzwS6Xdu7cqUWLFkn69N2TjIwMrVmzRj/5yU8kScFgUGlpadq2bZvuuecevfvuu8rOztbhw4c1a9YsSdKePXu0cOFCffDBB8rIyPjK5w2FQvJ6vQoGg/J4PH2dPgAAuIqi+fvdr+egnDlzRoFAQHl5ec59Xq9Xubm5qqmpkSTV1NQoOTnZiRNJysvLU1xcnA4ePHjZx+3s7FQoFIrYAADA0NWvgRIIBCRJaWlpEfenpaU5+wKBgFJTUyP2x8fHKyUlxRnzWeXl5fJ6vc6WmZnZn9MGAACWGRRX8ZSVlSkYDDrb2bNnYz0lAAAwgPo1UHw+nySpqakp4v6mpiZnn8/nU3Nzc8T+7u5uXbhwwRnzWW63Wx6PJ2IDAABDV78GyoQJE+Tz+VRVVeXcFwqFdPDgQfn9fkmS3+9XS0uLamtrnTF79+5VOBxWbm5uf04HAAAMUvHRHtDW1qZTp045t8+cOaO6ujqlpKQoKytLq1at0j/+4z9q0qRJmjBhgn7+858rIyPDudJn6tSpmj9/vpYvX66tW7eqq6tLJSUluueee3p1BQ8AABj6og6Ut99+W9///ved26WlpZKkoqIibdu2TT/96U/V3t6uFStWqKWlRd/97ne1Z88eXXPNNc4xzz33nEpKSjR37lzFxcWpsLBQTz75ZD+8HAAAMBRc0fegxArfgwIAwOATs+9BAQAA6A8ECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwTtSBsn//ft1+++3KyMiQy+XSiy++GLH/vvvuk8vlitjmz58fMebChQtavHixPB6PkpOTtWzZMrW1tV3RCwEAAENH1IHS3t6uGTNm6KmnnvrCMfPnz1djY6OzPf/88xH7Fy9erBMnTqiyslK7du3S/v37tWLFiuhnDwAAhqT4aA9YsGCBFixY8KVj3G63fD7fZfe9++672rNnjw4fPqxZs2ZJkn71q19p4cKF+ud//mdlZGREOyUAADDEDMg5KPv27VNqaqomT56slStX6qOPPnL21dTUKDk52YkTScrLy1NcXJwOHjx42cfr7OxUKBSK2AAAwNDV74Eyf/58/ed//qeqqqr0T//0T6qurtaCBQvU09MjSQoEAkpNTY04Jj4+XikpKQoEApd9zPLycnm9XmfLzMzs72kDAACLRP0Rz1e55557nJ+nT5+unJwcXX/99dq3b5/mzp3bp8csKytTaWmpczsUChEpAAAMYQN+mfF1112nMWPG6NSpU5Ikn8+n5ubmiDHd3d26cOHCF5634na75fF4IjYAADB0DXigfPDBB/roo4+Unp4uSfL7/WppaVFtba0zZu/evQqHw8rNzR3o6QAAgEEg6o942tranHdD
JOnMmTOqq6tTSkqKUlJStHHjRhUWFsrn8+n06dP66U9/qokTJyo/P1+SNHXqVM2fP1/Lly/X1q1b1dXVpZKSEt1zzz1cwQMAACRJLmOMieaAffv26fvf//7n7i8qKtKWLVu0aNEiHTlyRC0tLcrIyNC8efP0i1/8Qmlpac7YCxcuqKSkRC+//LLi4uJUWFioJ598UklJSb2aQygUktfrVTAY5OMeAAAGiWj+fkcdKDYgUAAAGHyi+fvN/+IBAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHWiCpTy8nLNnj1bo0aNUmpqqhYtWqT6+vqIMR0dHSouLtbo0aOVlJSkwsJCNTU1RYxpaGhQQUGBRowYodTUVK1du1bd3d1X/moAAMCQEFWgVFdXq7i4WAcOHFBlZaW6uro0b948tbe3O2NWr16tl19+WTt27FB1dbXOnTunO++809nf09OjgoICXbx4UW+99ZaeffZZbdu2TQ899FD/vSoAADCouYwxpq8Hnz9/XqmpqaqurtYtt9yiYDCosWPHqqKiQnfddZck6b333tPUqVNVU1OjOXPmaPfu3brtttt07tw5paWlSZK2bt2qdevW6fz580pMTPzK5w2FQvJ6vQoGg/J4PH2dPgAAuIqi+ft9ReegBINBSVJKSookqba2Vl1dXcrLy3PGTJkyRVlZWaqpqZEk1dTUaPr06U6cSFJ+fr5CoZBOnDhx2efp7OxUKBSK2AAAwNDV50AJh8NatWqVbr75Zk2bNk2SFAgElJiYqOTk5IixaWlpCgQCzpg/j5NL+y/tu5zy8nJ5vV5ny8zM7Ou0AQDAINDnQCkuLtbx48e1ffv2/pzPZZWVlSkYDDrb2bNnB/w5AQBA7MT35aCSkhLt2rVL+/fv17hx45z7fT6fLl68qJaWloh3UZqamuTz+Zwxhw4dini8S1f5XBrzWW63W263uy9TBQAAg1BU76AYY1RSUqKdO3dq7969mjBhQsT+mTNnKiEhQVVVVc599fX1amhokN/vlyT5/X4dO3ZMzc3NzpjKykp5PB5lZ2dfyWsBAABDRFTvoBQXF6uiokIvvfSSRo0a5Zwz4vV6NXz4cHm9Xi1btkylpaVKSUmRx+PRAw88IL/frzlz5kiS5s2bp+zsbC1ZskSbN29WIBDQ+vXrVVxczLskAABAUpSXGbtcrsve/8wzz+i+++6T9OkXta1Zs0bPP/+8Ojs7lZ+fr6effjri45s//elPWrlypfbt26eRI0eqqKhImzZtUnx873qJy4wBABh8ovn7fUXfgxIrBAoAAIPPVfseFAAAgIFAoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADgqvjo/YO9HhtVoJSXl2v27NkaNWqUUlNTtWjRItXX10eMufXWW+VyuSK2+++/P2JMQ0ODCgoKNGLECKWmpmrt2rXq7u6OZioAAGCQaTyyp9dj46N54OrqahUXF2v27Nnq7u7Wz372M82bN08nT57UyJEjnXHLly/XI4884tweMWKE83NPT48KCgrk8/n01ltvqbGxUUuXLlVCQoIeffTRaKYDAACGqKgCZc+eyPLZtm2bUlNTVVtbq1tuucW5f8SIEfL5fJd9
jN///vc6efKkXn/9daWlpenGG2/UL37xC61bt04PP/ywEhMT+/AyAADAUHJF56AEg0FJUkpKSsT9zz33nMaMGaNp06aprKxMH3/8sbOvpqZG06dPV1pamnNffn6+QqGQTpw4cdnn6ezsVCgUitgAAMDQFdU7KH8uHA5r1apVuvnmmzVt2jTn/h/+8IcaP368MjIydPToUa1bt0719fV64YUXJEmBQCAiTiQ5twOBwGWfq7y8XBs3buzrVAEAQIyFe7pkZHo9vs+BUlxcrOPHj+vNN9+MuH/FihXOz9OnT1d6errmzp2r06dP6/rrr+/Tc5WVlam0tNS5HQqFlJmZ2beJAwCAq66742PJhHs9vk8f8ZSUlGjXrl164403NG7cuC8dm5ubK0k6deqUJMnn86mpqSlizKXbX3TeitvtlsfjidgAAMDg0d3ZJjNQgWKMUUlJiXbu3Km9e/dqwoQJX3lMXV2dJCk9PV2S5Pf7dezYMTU3NztjKisr5fF4lJ2dHc10AADAINHd0S4T7n2gRPURT3FxsSoqKvTSSy9p1KhRzjkjXq9Xw4cP1+nTp1VRUaGFCxdq9OjROnr0qFavXq1bbrlFOTk5kqR58+YpOztbS5Ys0ebNmxUIBLR+/XoVFxfL7XZHMx0AADBI9HS2D9w7KFu2bFEwGNStt96q9PR0Z/v1r38tSUpMTNTrr7+uefPmacqUKVqzZo0KCwv18ssvO48xbNgw7dq1S8OGDZPf79ff/u3faunSpRHfmwIAAIaW9uY/Ktx9sdfjo3oHxZgvP/s2MzNT1dXVX/k448eP16uvvhrNUwMAgEHsk5ZGmZ7ef2s8/4sHAABYh0ABAADWIVAAAMCA+vTk2N5/SZtEoAAAgAEW7u6S6emJ6hgCBQAADKierg6Fe7qiOoZAAQAAAyrc1RHVJcYSgQIAAAZYz0XeQQEAAJbpbP1Q3Z+0RnUMgQIAAAZUR0tAXR8HozqGQAEAANYhUAAAgHUIFAAAYB0CBQAADBgTDsuEo/uSNolAAQAAAyjc062eix1RH0egAACAAWN6utTTRaAAAACLfPoOyidRH0egAACAAdNzsV0XWz+K+jgCBQAADJiLbf+n9vN/jPo4AgUAAFgnPtYTAAAA9uru7r6i43t6or/EWCJQAADAl5g8ebIaGhr6fHzu1G/ql8Xzoj6OQAEAAF+ou7v7it5FCffhS9okzkEBAAADxOWSRrgTnNsXLqb3+ljeQQEAAAMizuWSN+kaSdKpj7+jU+0ZvT6WQAEAAAMiLs4l78hrdOaT6Tr98Y36JNz7b5QlUAAAwIAYFhennsSJeq99jiRXVMdyDgoAABgQCfFxmpSZomjjRCJQAADAAHEnxOs73+r9eSd/jkABAADWIVAAAMCAGZ1wThNHvC2XwlEdF1WgbNmyRTk5OfJ4PPJ4PPL7/dq9e7ezv6OjQ8XFxRo9erSSkpJUWFiopqamiMdoaGhQQUGBRowYodTUVK1du/aKv0YXAADYKc4V1sTh7+ja4cfkdn3c6+Oiuopn3Lhx2rRpkyZNmiRjjJ599lndcccdOnLkiG644QatXr1ar7zyinbs2CGv16uSkhLdeeed+sMf/iDp0+/jLygokM/n01tvvaXGxkYtXbpUCQkJevTRR6N7xQAAwGodF7v14pvv/f9b7+lsW++/qM1ljDFX8uQpKSl67LHHdNddd2ns2LGqqKjQXXfd9elU3ntPU6dOVU1NjebMmaPdu3frtttu07lz55SWliZJ2rp1q9atW6fz588rMTGxV88ZCoXk9Xp133339foYAAAQvYqKCrW1tfXrYwaDQXk8ni8d0+fvQenp6dGOHTvU3t4uv9+v2tpadXV1KS8vzxkzZcoUZWVlOYFSU1Oj6dOnO3EiSfn5+Vq5cqVOnDihb3/725d9rs7OTnV2djq3Q6GQJGnJkiVKSkrq60sAAABf4Xe/+12/B0pvRB0ox44dk9/vV0dHh5KSkrRz505lZ2er
rq5OiYmJSk5OjhiflpamQCAgSQoEAhFxcmn/pX1fpLy8XBs3bvzc/bNmzfrKAgMAAH0Xq08qor6KZ/Lkyaqrq9PBgwe1cuVKFRUV6eTJkwMxN0dZWZmCwaCznT17dkCfDwAAxFbU76AkJiZq4sSJkqSZM2fq8OHDeuKJJ3T33Xfr4sWLamlpiXgXpampST6fT5Lk8/l06NChiMe7dJXPpTGX43a75Xa7o50qAAAYpK74e1DC4bA6Ozs1c+ZMJSQkqKqqytlXX1+vhoYG+f1+SZLf79exY8fU3NzsjKmsrJTH41F2dvaVTgUAAAwRUb2DUlZWpgULFigrK0utra2qqKjQvn379Nprr8nr9WrZsmUqLS1VSkqKPB6PHnjgAfn9fs2ZM0eSNG/ePGVnZ2vJkiXavHmzAoGA1q9fr+LiYt4hAQAAjqgCpbm5WUuXLlVjY6O8Xq9ycnL02muv6Qc/+IEk6Ze//KXi4uJUWFiozs5O5efn6+mnn3aOHzZsmHbt2qWVK1fK7/dr5MiRKioq0iOPPNK/rwoAAAxqV/w9KLFw6XtQenMdNQAA6Lvx48eroaGhXx+zN3+/+V88AADAOgQKAACwDoECAACsQ6AAAADr9Pl/8QAAgKEvPz9f58+f75fH6urq0iuvvNKrsVzFAwAAropo/n7zEQ8AALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6UQXKli1blJOTI4/HI4/HI7/fr927dzv7b731Vrlcrojt/vvvj3iMhoYGFRQUaMSIEUpNTdXatWvV3d3dP68GAAAMCfHRDB43bpw2bdqkSZMmyRijZ599VnfccYeOHDmiG264QZK0fPlyPfLII84xI0aMcH7u6elRQUGBfD6f3nrrLTU2Nmrp0qVKSEjQo48+2k8vCQAADHYuY4y5kgdISUnRY489pmXLlunWW2/VjTfeqMcff/yyY3fv3q3bbrtN586dU1pamiRp69atWrdunc6fP6/ExMRePWcoFJLX61UwGJTH47mS6QMAgKskmr/ffT4HpaenR9u3b1d7e7v8fr9z/3PPPacxY8Zo2rRpKisr08cff+zsq6mp0fTp0504kaT8/HyFQiGdOHHiC5+rs7NToVAoYgMAAENXVB/xSNKxY8fk9/vV0dGhpKQk7dy5U9nZ2ZKkH/7whxo/frwyMjJ09OhRrVu3TvX19XrhhRckSYFAICJOJDm3A4HAFz5neXm5Nm7cGO1UAQDAIBV1oEyePFl1dXUKBoP67W9/q6KiIlVXVys7O1srVqxwxk2fPl3p6emaO3euTp8+reuvv77PkywrK1NpaalzOxQKKTMzs8+PBwAA7Bb1RzyJiYmaOHGiZs6cqfLycs2YMUNPPPHEZcfm5uZKkk6dOiVJ8vl8ampqihhz6bbP5/vC53S73c6VQ5c2AAAwdF3x96CEw2F1dnZedl9dXZ0kKT09XZLk9/t17NgxNTc3O2MqKyvl8Xicj4kAAACi+oinrKxMCxYsUFZWllpbW1VRUaF9+/bptdde0+nTp1VRUaGFCxdq9OjROnr0qFavXq1bbrlFOTk5kqR58+YpOztbS5Ys0ebNmxUIBLR+/XoVFxfL7XYPyAsEAACDT1SB0tzcrKVLl6qxsVFer1c5OTl67bXX9IMf/EBnz57V66+/rscff1zt7e3KzMxUYWGh1q9f7xw/bNgw7dq1SytXrpTf79fIkSNVVFQU8b0pAAAAV/w9KLHA96AAADD4XJXvQQEAABgoBAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAO
gQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOvGxnkBfGGMkSaFQKMYzAQAAvXXp7/alv+NfZlAGSmtrqyQpMzMzxjMBAADRam1tldfr/dIxLtObjLFMOBxWfX29srOzdfbsWXk8nlhPadAKhULKzMxkHfsBa9l/WMv+wTr2H9ayfxhj1NraqoyMDMXFfflZJoPyHZS4uDh985vflCR5PB5+WfoB69h/WMv+w1r2D9ax/7CWV+6r3jm5hJNkAQCAdQgUAABgnUEbKG63Wxs2bJDb7Y71VAY11rH/sJb9h7XsH6xj/2Etr75BeZIsAAAY2gbtOygAAGDoIlAAAIB1CBQAAGAdAgUAAFhnUAbKU089pWuvvVbXXHONcnNzdejQoVhPyTr79+/X7bffroyMDLlcLr344osR+40xeuihh5Senq7hw4crLy9P77//fsSYCxcuaPHixfJ4PEpOTtayZcvU1tZ2FV9F7JWXl2v27NkaNWqUUlNTtWjRItXX10eM6ejoUHFxsUaPHq2kpCQVFhaqqakpYkxDQ4MKCgo0YsQIpaamau3ateru7r6aLyWmtmzZopycHOdLrvx+v3bv3u3sZw37btOmTXK5XFq1apVzH+vZOw8//LBcLlfENmXKFGc/6xhjZpDZvn27SUxMNP/xH/9hTpw4YZYvX26Sk5NNU1NTrKdmlVdffdX8wz/8g3nhhReMJLNz586I/Zs2bTJer9e8+OKL5r//+7/NX//1X5sJEyaYTz75xBkzf/58M2PGDHPgwAHzX//1X2bixInm3nvvvcqvJLby8/PNM888Y44fP27q6urMwoULTVZWlmlra3PG3H///SYzM9NUVVWZt99+28yZM8f8xV/8hbO/u7vbTJs2zeTl5ZkjR46YV1991YwZM8aUlZXF4iXFxO9+9zvzyiuvmP/5n/8x9fX15mc/+5lJSEgwx48fN8awhn116NAhc+2115qcnBzz4x//2Lmf9eydDRs2mBtuuME0NjY62/nz5539rGNsDbpAuemmm0xxcbFzu6enx2RkZJjy8vIYzspunw2UcDhsfD6feeyxx5z7WlpajNvtNs8//7wxxpiTJ08aSebw4cPOmN27dxuXy2X+93//96rN3TbNzc1GkqmurjbGfLpuCQkJZseOHc6Yd99910gyNTU1xphPYzEuLs4EAgFnzJYtW4zH4zGdnZ1X9wVY5Bvf+Ib5t3/7N9awj1pbW82kSZNMZWWl+cu//EsnUFjP3tuwYYOZMWPGZfexjrE3qD7iuXjxompra5WXl+fcFxcXp7y8PNXU1MRwZoPLmTNnFAgEItbR6/UqNzfXWceamholJydr1qxZzpi8vDzFxcXp4MGDV33OtggGg5KklJQUSVJtba26uroi1nLKlCnKysqKWMvp06crLS3NGZOfn69QKKQTJ05cxdnboaenR9u3b1d7e7v8fj9r2EfFxcUqKCiIWDeJ38lovf/++8rIyNB1112nxYsXq6GhQRLraINB9c8CP/zwQ/X09ET8MkhSWlqa3nvvvRjNavAJBAKSdNl1vLQvEAgoNTU1Yn98fLxSUlKcMV834XBYq1at0s0336xp06ZJ+nSdEhMTlZycHDH2s2t5ubW+tO/r4tixY/L7/ero6FBSUpJ27typ7Oxs1dXVsYZR2r59u9555x0dPnz4c/v4ney93Nxcbdu2TZMnT1ZjY6M2btyo733vezp+/DjraIFBFShALBUXF+v48eN68803Yz2VQWny5Mmqq6tTMBjUb3/7WxUVFam6ujrW0xp0zp49qx//+MeqrKzUNddcE+vpDGoLFixwfs7JyVFubq7Gjx+v3/zmNxo+fHgMZwZpkF3FM2bMGA0bNuxzZ1E3NTXJ5/PFaFaDz6W1+rJ1
9Pl8am5ujtjf3d2tCxcufC3XuqSkRLt27dIbb7yhcePGOff7fD5dvHhRLS0tEeM/u5aXW+tL+74uEhMTNXHiRM2cOVPl5eWaMWOGnnjiCdYwSrW1tWpubtZ3vvMdxcfHKz4+XtXV1XryyScVHx+vtLQ01rOPkpOT9a1vfUunTp3i99ICgypQEhMTNXPmTFVVVTn3hcNhVVVVye/3x3Bmg8uECRPk8/ki1jEUCungwYPOOvr9frW0tKi2ttYZs3fvXoXDYeXm5l71OceKMUYlJSXauXOn9u7dqwkTJkTsnzlzphISEiLWsr6+Xg0NDRFreezYsYjgq6yslMfjUXZ29tV5IRYKh8Pq7OxkDaM0d+5cHTt2THV1dc42a9YsLV682PmZ9eybtrY2nT59Wunp6fxe2iDWZ+lGa/v27cbtdptt27aZkydPmhUrVpjk5OSIs6jx6Rn+R44cMUeOHDGSzL/8y7+YI0eOmD/96U/GmE8vM05OTjYvvfSSOXr0qLnjjjsue5nxt7/9bXPw4EHz5ptvmkmTJn3tLjNeuXKl8Xq9Zt++fRGXIn788cfOmPvvv99kZWWZvXv3mrffftv4/X7j9/ud/ZcuRZw3b56pq6sze/bsMWPHjv1aXYr44IMPmurqanPmzBlz9OhR8+CDDxqXy2V+//vfG2NYwyv151fxGMN69taaNWvMvn37zJkzZ8wf/vAHk5eXZ8aMGWOam5uNMaxjrA26QDHGmF/96lcmKyvLJCYmmptuuskcOHAg1lOyzhtvvGEkfW4rKioyxnx6qfHPf/5zk5aWZtxut5k7d66pr6+PeIyPPvrI3HvvvSYpKcl4PB7zox/9yLS2tsbg1cTO5dZQknnmmWecMZ988on5+7//e/ONb3zDjBgxwvzN3/yNaWxsjHicP/7xj2bBggVm+PDhZsyYMWbNmjWmq6vrKr+a2Pm7v/s7M378eJOYmGjGjh1r5s6d68SJMazhlfpsoLCevXP33Xeb9PR0k5iYaL75zW+au+++25w6dcrZzzrGlssYY2Lz3g0AAMDlDapzUAAAwNcDgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6/w+5ZA0jCLWsqgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "171.0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run one evaluation episode with play=True, which displays the rendered\n",
    "# environment (see the figure output); the number shown is test()'s return value\n",
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
